{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6909d379-507b-400d-81ca-c16b97fe33f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "from tqdm.auto import tqdm\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from elasticsearch import Elasticsearch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "56e62c2d-7bb2-4624-82fa-7cbba161e313",
   "metadata": {},
   "outputs": [],
   "source": [
    "# JSON files are UTF-8; be explicit so loading does not depend on the platform's default encoding.\n",
    "with open('documents-with-ids.json', 'rt', encoding='utf-8') as f_in:\n",
    "    documents = json.load(f_in)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6a05516c-29e7-4d8f-83ff-e22aab46285a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/python/3.10.13/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Sentence-transformers checkpoint used to embed the documents in the indexing stage.\n",
    "model_name = 'multi-qa-MiniLM-L6-cos-v1'\n",
    "model = SentenceTransformer(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48e155b6-eb4e-4049-a81e-d25dd0047851",
   "metadata": {},
   "source": [
    "## Indexing stage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2cb482e5-b353-4601-b83c-27277f210c14",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "38ececb390834576aeb98f0145d9bf96",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/948 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for doc in tqdm(documents):\n",
    "    question, text = doc['question'], doc['text']\n",
    "\n",
    "    # Three embeddings per entry: question only, text only, and both joined by a space.\n",
    "    doc.update(\n",
    "        question_vector=model.encode(question),\n",
    "        text_vector=model.encode(text),\n",
    "        question_text_vector=model.encode(question + ' ' + text),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b97df08d-b011-49ab-8ff8-18cf2335c0e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'course-questions'})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "es_client = Elasticsearch('http://localhost:9200')\n",
    "\n",
    "index_name = \"course-questions\"\n",
    "\n",
    "# The three embedding fields share one dense_vector mapping, so it is generated per\n",
    "# field name instead of being spelled out three times.\n",
    "vector_fields = (\"question_vector\", \"text_vector\", \"question_text_vector\")\n",
    "\n",
    "index_settings = {\n",
    "    \"settings\": {\"number_of_shards\": 1, \"number_of_replicas\": 0},\n",
    "    \"mappings\": {\n",
    "        \"properties\": {\n",
    "            \"text\": {\"type\": \"text\"},\n",
    "            \"section\": {\"type\": \"text\"},\n",
    "            \"question\": {\"type\": \"text\"},\n",
    "            \"course\": {\"type\": \"keyword\"},\n",
    "            \"id\": {\"type\": \"keyword\"},\n",
    "            **{\n",
    "                vector_field: {\n",
    "                    \"type\": \"dense_vector\",\n",
    "                    \"dims\": 384,\n",
    "                    \"index\": True,\n",
    "                    \"similarity\": \"cosine\",\n",
    "                }\n",
    "                for vector_field in vector_fields\n",
    "            },\n",
    "        }\n",
    "    },\n",
    "}\n",
    "\n",
    "# Start from a clean index on every run; `ignore_unavailable` keeps the delete from\n",
    "# failing when the index does not exist yet.\n",
    "es_client.indices.delete(index=index_name, ignore_unavailable=True)\n",
    "es_client.indices.create(index=index_name, body=index_settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cc16aaec-6c50-4599-accb-6014da69b57a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c70fef72c1024a908fa2f40c38222d54",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/948 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for doc in tqdm(documents):\n",
    "    es_client.index(index=index_name, document=doc)\n",
    "\n",
    "# Newly indexed documents only become searchable after a refresh; force one so the\n",
    "# retrieval cells see every document even when the notebook is run straight through.\n",
    "es_client.indices.refresh(index=index_name);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da8f33f2-7152-42cf-b246-179434bbee94",
   "metadata": {},
   "source": [
    "## Retrieval stage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "19cda3c6-ac39-4ad3-9f97-d90b7e2b2f0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.embeddings import SentenceTransformerEmbeddings\n",
    "from typing import Dict\n",
    "from langchain_elasticsearch import ElasticsearchRetriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5f17e7b4-81c4-4f04-acdc-60813969ab2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Elasticsearch endpoint handed to the LangChain retrievers; same address the indexing client uses.\n",
    "es_url = 'http://localhost:9200'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2a4a460b-78ac-4b80-b298-283a4b1888eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample question for the retrieval demo.\n",
    "query = 'I just discovered the course. Can I still join it?'\n",
    "# Read as a global by `hybrid_query` to restrict results to one course.\n",
    "course = \"data-engineering-zoomcamp\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f5bfe58e-b9aa-4eaa-985f-06b9c3ba725e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive the checkpoint from `model_name` so query embeddings always come from the same\n",
    "# model that produced the indexed document vectors.\n",
    "embeddings = SentenceTransformerEmbeddings(model_name=f\"sentence-transformers/{model_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a28528a8-cbde-48e7-8737-46002a90b1b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hybrid_query(search_query: str) -> Dict:\n",
    "    \"\"\"Build the Elasticsearch request body for a hybrid (keyword + kNN) search.\n",
    "\n",
    "    The keyword part is a `multi_match` over the question, text and section fields; the\n",
    "    vector part is a kNN search on `question_text_vector`. Both carry a boost of 0.5 and\n",
    "    are filtered to the course held in the global `course` variable. 5 hits are requested.\n",
    "    \"\"\"\n",
    "    query_vector = embeddings.embed_query(search_query)  # same embeddings as for indexing\n",
    "\n",
    "    keyword_clause = {\n",
    "        \"multi_match\": {\n",
    "            \"query\": search_query,\n",
    "            \"fields\": [\"question\", \"text\", \"section\"],\n",
    "            \"type\": \"best_fields\",\n",
    "            \"boost\": 0.5,\n",
    "        }\n",
    "    }\n",
    "    knn_clause = {\n",
    "        \"field\": \"question_text_vector\",\n",
    "        \"query_vector\": query_vector,\n",
    "        \"k\": 5,\n",
    "        \"num_candidates\": 10000,\n",
    "        \"boost\": 0.5,\n",
    "        \"filter\": {\"term\": {\"course\": course}},\n",
    "    }\n",
    "\n",
    "    return {\n",
    "        \"query\": {\n",
    "            \"bool\": {\n",
    "                \"must\": keyword_clause,\n",
    "                \"filter\": {\"term\": {\"course\": course}},\n",
    "            }\n",
    "        },\n",
    "        \"knn\": knn_clause,\n",
    "        \"size\": 5,\n",
    "    }\n",
    "\n",
    "\n",
    "# Retriever bound to `index_name` whose request body is produced by `hybrid_query`.\n",
    "hybrid_retriever = ElasticsearchRetriever.from_es_params(\n",
    "    index_name=index_name,\n",
    "    body_func=hybrid_query,\n",
    "    content_field='text',\n",
    "    url=es_url,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b7522a18-0238-404e-bf0c-d389e6a95b54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the sample `query` through the hybrid retriever.\n",
    "hybrid_results = hybrid_retriever.invoke(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "748ae321-dda7-4e16-a3e6-be30d71cafe3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Course - Can I still join the course after the start date? data-engineering-zoomcamp 12.559245\n",
      "Course - Can I follow the course after it finishes? data-engineering-zoomcamp 9.39959\n",
      "Course - What can I do before the course starts? data-engineering-zoomcamp 7.306914\n",
      "Course - Can I get support if I take the course in the self-paced mode? data-engineering-zoomcamp 7.1085525\n",
      "Course - When will the course start? data-engineering-zoomcamp 6.7513986\n"
     ]
    }
   ],
   "source": [
    "# One line per hit: question, course and the Elasticsearch score.\n",
    "for result in hybrid_results:\n",
    "    source = result.metadata['_source']\n",
    "    print(source['question'], source['course'], result.metadata['_score'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d097e43-0621-46aa-955f-dc5f10407048",
   "metadata": {},
   "source": [
    "### Hybrid search evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "efe387cd-ceef-4f56-bc3c-908b518cef89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth question/document pairs used to evaluate retrieval.\n",
    "df_ground_truth = pd.read_csv('ground-truth-data.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "49acd0a7-6c42-4030-97e0-5e65d096cdbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One dict per row, so each record can be handed to a search function as-is.\n",
    "ground_truth = df_ground_truth.to_dict(orient='records')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "cc01b7c9-033b-42a5-abb3-45a2cc3468a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hit_rate(relevance_total):\n",
    "    cnt = 0\n",
    "\n",
    "    for line in relevance_total:\n",
    "        if True in line:\n",
    "            cnt = cnt + 1\n",
    "\n",
    "    return cnt / len(relevance_total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "90f27865-25f6-42ac-b8f0-bc2c9d595560",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mrr(relevance_total):\n",
    "    total_score = 0.0\n",
    "\n",
    "    for line in relevance_total:\n",
    "        for rank in range(len(line)):\n",
    "            if line[rank] == True:\n",
    "                total_score = total_score + 1 / (rank + 1)\n",
    "\n",
    "    return total_score / len(relevance_total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "96963309-7cfd-4190-b4b6-2a41befda10a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def elastic_search_hybrid(field, query, course):\n",
    "    \"\"\"Run a hybrid (keyword + kNN) search and return the `_source` of each hit.\n",
    "\n",
    "    Args:\n",
    "        field: name of the dense-vector field the kNN part searches.\n",
    "        query: question text; matched as keywords and embedded for the kNN part.\n",
    "        course: course value both parts of the search are filtered to.\n",
    "\n",
    "    Returns:\n",
    "        A list with the `_source` dict of every returned hit. The request asks for 5 hits\n",
    "        and restricts `_source` to the text, section, question, course and id fields.\n",
    "    \"\"\"\n",
    "    def hybrid_query(search_query: str) -> Dict:\n",
    "        vector = embeddings.embed_query(search_query)  # same embeddings as for indexing\n",
    "        return {\n",
    "            \"query\": {\n",
    "                \"bool\": {\n",
    "                    \"must\": {\n",
    "                        \"multi_match\": {\n",
    "                            \"query\": search_query,\n",
    "                            \"fields\": [\"question\", \"text\", \"section\"],\n",
    "                            \"type\": \"best_fields\",\n",
    "                            \"boost\": 0.5,\n",
    "                        }\n",
    "                    },\n",
    "                    \"filter\": {\"term\": {\"course\": course}},\n",
    "                }\n",
    "            },\n",
    "            \"knn\": {\n",
    "                \"field\": field,\n",
    "                \"query_vector\": vector,\n",
    "                \"k\": 5,\n",
    "                \"num_candidates\": 10000,\n",
    "                \"boost\": 0.5,\n",
    "                \"filter\": {\"term\": {\"course\": course}},\n",
    "            },\n",
    "            \"size\": 5,\n",
    "            \"_source\": [\"text\", \"section\", \"question\", \"course\", \"id\"],\n",
    "        }\n",
    "\n",
    "    # Reuse the notebook's existing `es_client` instead of `from_es_params(url=...)`,\n",
    "    # which would create a new Elasticsearch client on every call, i.e. once per\n",
    "    # evaluated query.\n",
    "    hybrid_retriever = ElasticsearchRetriever(\n",
    "        es_client=es_client,\n",
    "        index_name=index_name,\n",
    "        body_func=hybrid_query,\n",
    "        content_field='text',\n",
    "    )\n",
    "\n",
    "    return [hit.metadata['_source'] for hit in hybrid_retriever.invoke(query)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "9524200d-a3c3-4c89-b279-30fddc73c464",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'When does the course begin?',\n",
       " 'course': 'data-engineering-zoomcamp',\n",
       " 'document': 'c02e79ef'}"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at one record: a question, its course and the id of the expected document.\n",
    "ground_truth[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "2799e5de-2fb6-4a9b-b64f-0bebf1be04bc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'section': 'General course-related questions',\n",
       "  'question': 'Course - When will the course start?',\n",
       "  'course': 'data-engineering-zoomcamp',\n",
       "  'id': 'c02e79ef'},\n",
       " {'section': 'General course-related questions',\n",
       "  'question': 'Course - Can I still join the course after the start date?',\n",
       "  'course': 'data-engineering-zoomcamp',\n",
       "  'id': '7842b56a'},\n",
       " {'section': 'General course-related questions',\n",
       "  'question': 'Course - Can I follow the course after it finishes?',\n",
       "  'course': 'data-engineering-zoomcamp',\n",
       "  'id': 'a482086d'},\n",
       " {'section': 'Module 1: Docker and Terraform',\n",
       "  'question': 'PGCLI - error column c.relhasoids does not exist',\n",
       "  'course': 'data-engineering-zoomcamp',\n",
       "  'id': 'c91ad8f2'},\n",
       " {'section': 'General course-related questions',\n",
       "  'question': 'Course - What are the prerequisites for this course?',\n",
       "  'course': 'data-engineering-zoomcamp',\n",
       "  'id': '1f6520ca'}]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke-test the search helper on the first ground-truth record. The fields are passed\n",
    "# directly instead of being assigned to `question`/`course`, so the global `course` that\n",
    "# the standalone `hybrid_query` reads is not silently overwritten.\n",
    "sample = ground_truth[0]\n",
    "elastic_search_hybrid('question_text_vector', sample['question'], sample['course'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "674aa620-d80f-4e38-8d28-406adcb94aae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def question_text_hybrid(q):\n",
    "    \"\"\"Hybrid-search `question_text_vector` with the question and course of record `q`.\"\"\"\n",
    "    return elastic_search_hybrid('question_text_vector', q['question'], q['course'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "4b4882cf-6ee5-4b13-8363-0a08eb0a78fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(ground_truth, search_function):\n",
    "    \"\"\"Score a search function against the ground-truth records.\n",
    "\n",
    "    For every record, `search_function(record)` is called and each returned document is\n",
    "    flagged by whether its 'id' equals the record's 'document'.\n",
    "\n",
    "    Returns:\n",
    "        A dict with the 'hit_rate' and 'mrr' computed over those flags.\n",
    "    \"\"\"\n",
    "    relevance_total = []\n",
    "\n",
    "    for record in tqdm(ground_truth):\n",
    "        expected_id = record['document']\n",
    "        found_docs = search_function(record)\n",
    "        relevance_total.append([found['id'] == expected_id for found in found_docs])\n",
    "\n",
    "    return dict(\n",
    "        hit_rate=hit_rate(relevance_total),\n",
    "        mrr=mrr(relevance_total),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "d3ba9f7a-83b5-4df0-9675-823f4c45246c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a46ae3020d9e45d0af918999373bf86a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4627 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'hit_rate': 0.9250054030689432, 'mrr': 0.8506231539514445}"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Full evaluation run: one hybrid search per ground-truth record.\n",
    "evaluate(ground_truth, question_text_hybrid)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "beccbfaa-15cd-42d0-9f14-c67fe4b751e1",
   "metadata": {},
   "source": [
    "Hybrid search with ES: `{'hit_rate': 0.9250054030689432, 'mrr': 0.8506231539514445}`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fc1926d",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
